{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "This part of the code is the Q-learning brain of the agent.\n",
    "All decisions are made here.\n",
    "View more on my tutorial page: https://morvanzhou.github.io/tutorials/\n",
    "\"\"\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "class QLearningTable:\n",
    "    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):\n",
    "        self.actions = actions  # a list\n",
    "        self.lr = learning_rate\n",
    "        self.gamma = reward_decay\n",
    "        self.epsilon = e_greedy\n",
    "        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)\n",
    "\n",
    "    def choose_action(self, observation):\n",
    "        self.check_state_exist(observation)\n",
    "        # action selection\n",
    "        if np.random.uniform() < self.epsilon:\n",
    "            # choose best action\n",
    "            state_action = self.q_table.loc[observation, :]\n",
    "            # some actions may have the same value, randomly choose on in these actions\n",
    "            action = np.random.choice(state_action[state_action == np.max(state_action)].index)\n",
    "        else:\n",
    "            # choose random action\n",
    "            action = np.random.choice(self.actions)\n",
    "        return action\n",
    "\n",
    "    def learn(self, s, a, r, s_):\n",
    "        self.check_state_exist(s_)\n",
    "        q_predict = self.q_table.loc[s, a]\n",
    "        if s_ != 'terminal':\n",
    "            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal\n",
    "        else:\n",
    "            q_target = r  # next state is terminal\n",
    "        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update\n",
    "\n",
    "    def check_state_exist(self, state):\n",
    "        if state not in self.q_table.index:\n",
    "            # append new state to q table\n",
    "            self.q_table = self.q_table.append(\n",
    "                pd.Series(\n",
    "                    [0]*len(self.actions),\n",
    "                    index=self.q_table.columns,\n",
    "                    name=state,\n",
    "                )\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
